{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import onnxruntime\n",
    "import torch\n",
    "import numpy as np\n",
    "import torch.nn.functional as F\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Locations of the checkpoint artifacts used by the cells below.\n",
    "input_path = 'checkpoint'\n",
    "tags_path = os.path.join(input_path, 'tags.txt')\n",
    "model_path = os.path.join(input_path, 'model.onnx')\n",
    "generator_path = os.path.join(input_path, 'Gs.pth')\n",
    "\n",
    "# Use the first CUDA device when one is available, otherwise fall back to the CPU.\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Number of latent vectors sampled per generator call in the bulk-generation loop.\n",
    "batch_size = 4\n",
    "# Value handed to torch.manual_seed before latents are sampled.\n",
    "seed = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the ONNX Runtime session for the tagger model stored at model_path.\n",
    "# The session is used below to score generated images against the tag list.\n",
    "# NOTE(review): newer onnxruntime GPU builds may require an explicit `providers=`\n",
    "# argument here -- confirm against the installed version.\n",
    "C = onnxruntime.InferenceSession(model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the tag list: one tag per line, surrounding whitespace removed,\n",
    "# blank lines dropped. The result is a NumPy array so it can be indexed by position.\n",
    "with open(tags_path, 'r') as tags_stream:\n",
    "    stripped_lines = map(str.strip, tags_stream)\n",
    "    tags = np.array(list(filter(None, stripped_lines)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# stylegan2 supplies the generator loader; utils supplies the ProgressWriter used later.\n",
    "import stylegan2\n",
    "from stylegan2 import utils\n",
    "\n",
    "# Load the generator from generator_path.\n",
    "# NOTE(review): map_location is presumably forwarded to torch.load -- TODO confirm.\n",
    "G = stylegan2.models.load(generator_path, map_location=device)\n",
    "# Being the cell's last expression, the value returned by .to() is shown as cell output.\n",
    "# NOTE(review): this move looks redundant given map_location above -- confirm.\n",
    "G.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_image_tensor(image_tensor, pixel_min=-1, pixel_max=1):\n",
    "    \"\"\"Rescale pixel values from [pixel_min, pixel_max] to [0, 1] and clamp.\n",
    "\n",
    "    Args:\n",
    "        image_tensor: tensor of pixel values expressed in the source range.\n",
    "        pixel_min: source-range value that maps to 0.\n",
    "        pixel_max: source-range value that maps to 1.\n",
    "\n",
    "    Returns:\n",
    "        A tensor of the same shape with every value limited to [0, 1].\n",
    "    \"\"\"\n",
    "    already_unit_range = pixel_min == 0 and pixel_max == 1\n",
    "    if not already_unit_range:\n",
    "        span = pixel_max - pixel_min\n",
    "        image_tensor = (image_tensor - pixel_min) / span\n",
    "    # Values outside the declared source range are cut off at the borders.\n",
    "    return image_tensor.clamp(min=0, max=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: generate one image, tag it, and print the tags scoring above 0.5.\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "qlatents = torch.randn(1, G.latent_size).to(device=device, dtype=torch.float32)\n",
    "# Inference only: no_grad keeps the generator pass from recording an autograd graph,\n",
    "# so the results can be converted to NumPy without detaching them first.\n",
    "with torch.no_grad():\n",
    "    generated = G(qlatents)\n",
    "    images = to_image_tensor(generated)\n",
    "    # 299 is the input size of the model; align_corners=False is what bilinear\n",
    "    # interpolation already defaults to, written out so the choice is explicit.\n",
    "    images = F.interpolate(images, size=(299, 299), mode='bilinear', align_corners=False)\n",
    "ort_inputs = {C.get_inputs()[0].name: images.cpu().numpy()}\n",
    "[predicted_labels] = C.run(None, ort_inputs)\n",
    "# Show the resized image; matplotlib wants channels last, hence the permute.\n",
    "plt.imshow(images[0].cpu().permute(1, 2, 0))\n",
    "# print out the tags whose score is above 0.5\n",
    "labels = [tags[i] for i, score in enumerate(predicted_labels[0]) if score > 0.5]\n",
    "print(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bulk generation: sample latents, run the generator, tag every generated image and\n",
    "# keep the (qlatents, dlatents, tag scores) triples in three aligned tensors.\n",
    "# reset seed\n",
    "torch.manual_seed(seed)\n",
    "iteration = 5000\n",
    "# 299 is the input size of the model\n",
    "tagger_input_size = 299\n",
    "# The tagger's input name does not change, so look it up once instead of per image.\n",
    "tagger_input_name = C.get_inputs()[0].name\n",
    "\n",
    "progress = utils.ProgressWriter(iteration)\n",
    "progress.write('Generating images...', step=False)\n",
    "\n",
    "# Per-batch results are collected in lists and concatenated once after the loop.\n",
    "# Growing a tensor with torch.cat on every iteration re-copies everything gathered\n",
    "# so far, and pre-sizing the dlatents buffer meant hard-coding its layer count.\n",
    "qlatents_batches = []\n",
    "dlatents_batches = []\n",
    "label_rows = []\n",
    "try:\n",
    "    for i in range(iteration):\n",
    "        qlatents = torch.randn(batch_size, G.latent_size).to(device=device, dtype=torch.float32)\n",
    "        with torch.no_grad():\n",
    "            generated, dlatents = G(latents=qlatents, return_dlatents=True)\n",
    "            # map pixel values into [0, 1]; this returns a new tensor, it is not in-place\n",
    "            generated = to_image_tensor(generated)\n",
    "            # resize the image to 299 * 299\n",
    "            images = F.interpolate(\n",
    "                generated,\n",
    "                size=(tagger_input_size, tagger_input_size),\n",
    "                mode='bilinear',\n",
    "                align_corners=False,\n",
    "            )\n",
    "        # One device-to-host copy per batch rather than one per image.\n",
    "        images_np = images.cpu().numpy()\n",
    "        ## tagger does not take input as batch, need to feed one by one\n",
    "        for image in images_np:\n",
    "            # image[np.newaxis] restores the leading batch axis of size 1.\n",
    "            [[predicted_labels]] = C.run(None, {tagger_input_name: image[np.newaxis]})\n",
    "            label_rows.append(predicted_labels)\n",
    "        # store the result\n",
    "        qlatents_batches.append(qlatents)\n",
    "        dlatents_batches.append(dlatents)\n",
    "\n",
    "        progress.step()\n",
    "\n",
    "    progress.write('Done!', step=False)\n",
    "finally:\n",
    "    # Close the progress writer even when generation is interrupted.\n",
    "    progress.close()\n",
    "\n",
    "# iteration must be at least 1: torch.cat / np.stack reject empty lists.\n",
    "qlatents_data = torch.cat(qlatents_batches)\n",
    "dlatents_data = torch.cat(dlatents_batches)\n",
    "# Stack the rows into one array first; building a tensor directly from a list of\n",
    "# ndarrays is a very slow path.\n",
    "labels_data = torch.from_numpy(np.stack(label_rows)).to(device=device, dtype=torch.float32)\n",
    "assert labels_data.shape[1] == len(tags), 'tagger output width does not match the tag list'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist everything gathered above in a single file. Tensors are moved to the CPU\n",
    "# before saving; the tag list is stored next to them.\n",
    "latents_bundle = {\n",
    "    'qlatents_data': qlatents_data.cpu(),\n",
    "    'dlatents_data': dlatents_data.cpu(),\n",
    "    'labels_data': labels_data.cpu(),\n",
    "    'tags': tags,\n",
    "}\n",
    "torch.save(latents_bundle, 'latents.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
